package com.kjs.web.jwt.wrapper;

import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Vector;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

/**
 * Request wrapper that exposes caller-supplied parameters together with the
 * parameters found in the wrapped request's query string.
 *
 * <p>Supplied parameters take precedence: a query-string parameter whose name
 * also appears in the supplied map is ignored. Only the query string is
 * consulted; parameters carried in the request body are not merged.
 *
 * <p>The effective parameter set is computed once in the constructor and is
 * not modified afterwards. The map passed to the constructor is copied, never
 * mutated.
 */
public class ParameterRequestWrapper extends HttpServletRequestWrapper {
    /** Charset used to percent-decode query-string names and values. */
    private static final String QUERY_ENCODING = "UTF-8";

    /** Effective parameters: supplied ones plus non-conflicting query-string ones. */
    private final Map<String, String[]> params;

    /**
     * Wraps {@code request} and overlays {@code params} on its query-string parameters.
     *
     * @param request the request to wrap
     * @param params  extra parameters to expose; {@code null} is treated as an empty map
     */
    public ParameterRequestWrapper(HttpServletRequest request, Map<String, String[]> params) {
        super(request);

        Map<String, String[]> merged = new LinkedHashMap<>();
        if (params != null) {
            // Iterate as Entry<?, ?> so that values placed through a raw-typed map
            // (e.g. a plain String) are normalised instead of failing on a cast later.
            for (Map.Entry<?, ?> entry : params.entrySet()) {
                merged.put((String) entry.getKey(), toValues(entry.getValue()));
            }
        }

        // Query-string parameters only fill in names the caller did not supply.
        // containsKey (not a null check) so a supplied name mapped to null still wins.
        for (Map.Entry<String, String[]> entry : parseQueryString(request.getQueryString()).entrySet()) {
            if (!merged.containsKey(entry.getKey())) {
                merged.put(entry.getKey(), entry.getValue());
            }
        }

        this.params = merged;
    }

    /**
     * Returns the first value of the named parameter.
     *
     * @param name parameter name
     * @return the first value, or {@code null} if the parameter is absent or has no values
     */
    @Override
    public String getParameter(String name) {
        String[] values = params.get(name);
        if (values == null || values.length == 0) {
            return null;
        }
        return values[0];
    }

    /**
     * Returns a read-only view of the effective parameters.
     *
     * @return unmodifiable map from parameter name to its values
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        return Collections.unmodifiableMap(params);
    }

    /**
     * Returns the names of all effective parameters.
     *
     * @return enumeration over the parameter names
     */
    @Override
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(params.keySet());
    }

    /**
     * Returns all values of the named parameter.
     *
     * @param name parameter name
     * @return the values, or {@code null} if the parameter is absent
     */
    @Override
    public String[] getParameterValues(String name) {
        return params.get(name);
    }

    /**
     * Parses a raw query string into decoded name/value pairs, keeping every
     * value of a repeated name in order of appearance. Segments without an
     * {@code '='} are skipped.
     */
    private static Map<String, String[]> parseQueryString(String queryString) {
        Map<String, String[]> parsed = new LinkedHashMap<>();
        if (queryString == null || queryString.trim().isEmpty()) {
            return parsed;
        }

        for (String pair : queryString.split("&")) {
            int separator = pair.indexOf('=');
            if (separator == -1) {
                continue;
            }

            String key = decode(pair.substring(0, separator));
            String value = decode(pair.substring(separator + 1));

            String[] existing = parsed.get(key);
            if (existing == null) {
                parsed.put(key, new String[] { value });
            } else {
                // Repeated name: append instead of dropping the later value.
                String[] grown = Arrays.copyOf(existing, existing.length + 1);
                grown[existing.length] = value;
                parsed.put(key, grown);
            }
        }
        return parsed;
    }

    /** Percent-decodes one query-string component, returning it unchanged if it is malformed. */
    private static String decode(String raw) {
        try {
            return URLDecoder.decode(raw, QUERY_ENCODING);
        } catch (UnsupportedEncodingException | IllegalArgumentException ignored) {
            // A malformed escape sequence should not fail the whole request;
            // fall back to the text exactly as it appeared in the query string.
            return raw;
        }
    }

    /** Normalises a supplied map value to a {@code String[]}; {@code null} stays {@code null}. */
    private static String[] toValues(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof String[]) {
            return (String[]) value;
        }
        return new String[] { value.toString() };
    }
}
